import argparse, json, datetime
from pathlib import Path
import numpy as np

def read_D_map(path):
    """Load the D(n) map from a JSON file and return it sorted by n.

    Two layouts are accepted:
      * an object mapping n -> D, e.g. ``{"1": 0.5, "2": 0.7}``;
      * a list of records with "n" and "D" keys, e.g. ``[{"n": 1, "D": 0.5}]``.

    Both n and D are coerced with ``float``.

    Returns:
        (ns, Ds): two parallel lists of floats, ordered by ascending n.

    Raises:
        ValueError: if the top-level JSON value is neither an object nor a
            list, or if two entries resolve to the same float n (e.g. keys
            "2" and "2.0"). Duplicates are rejected because downstream code
            would otherwise see the same context twice with conflicting D.
        KeyError: if a list record lacks "n" or "D".
    """
    data = json.loads(Path(path).read_text(encoding="utf-8"))
    if isinstance(data, dict):
        pairs = [(float(k), float(v)) for k, v in data.items()]
    elif isinstance(data, list):
        pairs = [(float(e["n"]), float(e["D"])) for e in data]
    else:
        raise ValueError("Unsupported D map format")
    pairs.sort(key=lambda t: t[0])
    # After sorting, equal n values are adjacent, so one pass finds them all.
    for (n_prev, _), (n_cur, _) in zip(pairs, pairs[1:]):
        if n_prev == n_cur:
            raise ValueError(f"Duplicate n={n_cur:g} in D map {path}")
    ns = [n for n, _ in pairs]
    Ds = [D for _, D in pairs]
    return ns, Ds

def read_rates_csv(path):
    """Read long-form transition probabilities and build one matrix per context.

    The CSV must have a header containing at least the columns ``n``, ``i``,
    ``j`` and ``p``. Each row sets ``P[i, j] = p`` for context ``n``. Entries
    that are not listed stay 0. If the same (n, i, j) appears more than once,
    the last row wins.

    The matrix for each n is square with side ``max(every i and j seen for
    that n) + 1``, so different contexts can yield different sizes.

    Returns:
        dict mapping float n -> square float ``np.ndarray``.

    Raises:
        ValueError: if a state index is negative (NumPy would otherwise wrap
            it around to the end of the matrix and silently corrupt it), or
            if a numeric field cannot be parsed.
        KeyError: if a required column is missing.
    """
    import csv

    by_n = {}
    with open(path, newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        for row in reader:
            n = float(row["n"])
            i = int(row["i"])
            j = int(row["j"])
            p = float(row["p"])
            if i < 0 or j < 0:
                raise ValueError(
                    f"Negative state index (i={i}, j={j}) at line "
                    f"{reader.line_num} of {path}"
                )
            by_n.setdefault(n, {})[(i, j)] = p

    P_by_n = {}
    for n, entries in by_n.items():
        dim = 1 + max(max(i, j) for i, j in entries)
        P = np.zeros((dim, dim), dtype=float)
        for (i, j), p in entries.items():
            P[i, j] = p
        P_by_n[n] = P
    return P_by_n

def load_pivot_params(pivot_params_path):
    """Load and validate the linear pivot parameters (a, b) of g(D) = a*D + b.

    The JSON file must contain numeric fields "a" and "b". Two constraints
    are enforced:
      * g(2) = 2a + b must equal 1 to within 1e-6;
      * the slope a must be strictly negative.

    Returns:
        (a, b) as floats.

    Raises:
        ValueError: if either constraint fails. This includes NaN/infinite
            parameters: ``json`` accepts ``NaN``/``Infinity`` literals, and a
            NaN g(2) would slip past a plain ``>`` tolerance comparison.
        KeyError: if "a" or "b" is missing.
    """
    data = json.loads(Path(pivot_params_path).read_text(encoding="utf-8"))
    a = float(data["a"])
    b = float(data["b"])
    g2 = a*2.0 + b
    # Written as "not <=" so that a NaN g2 (every comparison False) fails too.
    if not abs(g2 - 1.0) <= 1e-6:
        raise ValueError(f"Constraint failed: g(2)={g2} must be 1 within 1e-6")
    if not (a < 0.0):
        raise ValueError(f"Constraint failed: slope a={a} must be < 0")
    return a, b

def g_of_D(D, a, b):
    """Return the coupling g evaluated at D on the line a*D + b."""
    slope_term = a * D
    return slope_term + b

def build_tridiagonal_kernel(D, a, b, N):
    """Return the N x N analytic tridiagonal kernel for a given D.

    With g = g_of_D(D, a, b), the main diagonal holds D - 2*g and the first
    super- and sub-diagonals hold g. A 1 x 1 kernel is special-cased to [[D]].

    NOTE(review): the two boundary rows have only one neighbour, so they sum
    to D - g while interior rows sum to D, whereas the N == 1 case gets D
    rather than D - 2*g. Confirm this boundary convention is intended.
    """
    coupling = g_of_D(D, a, b)
    kernel = np.zeros((N, N), dtype=float)
    if N == 1:
        kernel[0, 0] = D
        return kernel
    diag_idx = np.arange(N)
    kernel[diag_idx, diag_idx] = D - 2.0*coupling
    # Indices 0..N-2 address the (k, k+1) and (k+1, k) neighbour pairs.
    lower = diag_idx[:-1]
    kernel[lower, lower + 1] = coupling
    kernel[lower + 1, lower] = coupling
    return kernel

def ensure_dir(p):
    """Create directory *p* and any missing parents; an existing one is fine."""
    target = Path(p)
    target.mkdir(exist_ok=True, parents=True)

def _save_heatmaps(kernel_index, outdir, limit=4):
    # Render the first `limit` saved kernels as PNG heatmaps inside `outdir`.
    import matplotlib.pyplot as plt

    for ent in kernel_index[:limit]:
        M = np.load(ent["path"])
        plt.figure()
        plt.imshow(M, aspect='equal')
        plt.colorbar()
        plt.title(f"Heatmap M(n={ent['n']})")
        # Same `:g` formatting as the kernel .npy names, so a heatmap and its
        # kernel carry the same n label (e.g. "n=2" rather than "n=2.0").
        png = Path(outdir) / f"heatmap_M_n={ent['n']:g}.png"
        plt.savefig(png, dpi=160, bbox_inches="tight")
        plt.close()


def main():
    """Command-line entry point: build one kernel M(n) per context n.

    Reads D(n) from --D and, for each n in ascending order, builds either:
      * analytic: an N x N tridiagonal kernel (N = --dim) whose off-diagonals
        are g(D) = a*D + b, with (a, b) loaded from --pivot_params; or
      * empirical: M(n) = D(n) * P(n), where P(n) comes from the long-form
        --rates CSV and each row of P(n) must sum to ~1.

    Each kernel is saved as ``<outdir>/kernels/M_n=<n:g>.npy``. An index
    (kernel_index.json) and a build spec (kernel_specs.json) are written to
    --outdir. With --plots, heatmaps of the first (up to) four kernels are
    saved as PNGs in --outdir.

    Exits via SystemExit (non-zero status) on invalid arguments, on a context
    with no rates, or on a rate matrix whose rows do not sum to ~1.
    """
    ap = argparse.ArgumentParser(description="Build per-context kernels M(n) (analytic tridiagonal or empirical rates).")
    ap.add_argument("--D", required=True, help="Path to D_of_n.json (map)")
    ap.add_argument("--outdir", default="results", help="Output folder for kernels & specs")
    ap.add_argument("--mode", choices=["analytic","empirical"], default="analytic")
    ap.add_argument("--rates", help="flip_rates_by_context.csv (long form: n,i,j,p). Required for empirical mode")
    ap.add_argument("--dim", type=int, default=5, help="State dimension N for analytic tridiagonal")
    ap.add_argument("--pivot_params", default="pivot_params.json", help="Path to pivot_params.json")
    ap.add_argument("--plots", action="store_true", help="Also save heatmaps for a few M(n)")
    args = ap.parse_args()

    # Validate arguments before touching the filesystem, so a bad invocation
    # does not leave empty output folders behind.
    if args.mode == "empirical" and not args.rates:
        raise SystemExit("--rates is required for empirical mode")
    if args.mode == "analytic" and args.dim < 1:
        # dim == 0 would silently emit 0x0 kernels; dim < 0 would crash in NumPy.
        ap.error(f"--dim must be >= 1 for analytic mode (got {args.dim})")

    outdir = Path(args.outdir)
    ns, Ds = read_D_map(args.D)
    ensure_dir(outdir)
    kernels_dir = outdir / "kernels"
    ensure_dir(kernels_dir)

    spec = {
        "build_mode": args.mode,
        "mapping": "empirical: M(n)=D(n)·P(n); analytic: tridiagonal with off-diagonals g(D)",
        # Records the --dim argument; only the analytic builder uses it.
        "dim": args.dim,
        "pivot_params": args.pivot_params,
        "D_source": args.D,
        "rates_source": args.rates if args.rates else None,
        # datetime.utcnow() is deprecated since Python 3.12. Use an aware UTC
        # time, but keep the original "naive ISO + Z" string format.
        "created_utc": datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None).isoformat() + "Z",
    }

    if args.mode == "empirical":
        P_by_n = read_rates_csv(args.rates)
    else:
        a, b = load_pivot_params(args.pivot_params)

    kernel_index = []
    for n, D in zip(ns, Ds):
        if args.mode == "analytic":
            M = build_tridiagonal_kernel(D, a, b, args.dim)
        else:
            if n not in P_by_n:
                raise SystemExit(f"No rates found for n={n} in {args.rates}")
            P = P_by_n[n]
            rowsums = P.sum(axis=1)
            # np.allclose also applies a relative tolerance (default 1e-5), so
            # the effective bound is |rowsum - 1| <= 1e-5 + 1e-12, not 1e-12.
            # Both tolerances are spelled out so the real check is visible.
            if not np.allclose(rowsums, 1.0, rtol=1e-5, atol=1e-12):
                raise SystemExit(f"Row sums not ~1 for n={n}: max dev={np.max(np.abs(rowsums-1.0))}")
            M = float(D) * P
        out_path = kernels_dir / f"M_n={n:g}.npy"
        np.save(out_path, M)
        kernel_index.append({"n": n, "path": str(out_path), "shape": list(M.shape)})

    (outdir / "kernel_index.json").write_text(json.dumps(kernel_index, indent=2), encoding="utf-8")
    (outdir / "kernel_specs.json").write_text(json.dumps(spec, indent=2), encoding="utf-8")

    if args.plots:
        _save_heatmaps(kernel_index, outdir)

    print(f"Built {len(kernel_index)} kernels into {kernels_dir}")

if __name__ == "__main__":
    # Run the CLI. main() returns None, so a normal run exits with status 0.
    raise SystemExit(main())
